import torch
import torch.nn as nn

class Norm(nn.Module):
    """Selectable feature normalization layer.

    Depending on ``norm_type`` the module behaves as:

    * ``'bn'`` -- ``nn.BatchNorm1d(hidden_dim, track_running_stats=False)``.
    * ``'gn'`` -- normalization over dim 0 with a learnable per-feature
      ``mean_scale`` that controls how much of the mean is subtracted,
      followed by a learnable affine transform (``weight``, ``bias``).
    * anything else (including ``None``) -- identity.

    NOTE(review): inputs are assumed to be ``(num_rows, hidden_dim)`` so that
    the per-feature parameters broadcast over the last dim -- confirm against
    callers.
    """

    def __init__(self, norm_type, hidden_dim: int = 64, print_info=None,
                 eps: float = 1e-6):
        """Build the layer.

        Args:
            norm_type: ``'bn'``, ``'gn'``, or any other value for identity.
            hidden_dim: Number of features; size of every learnable vector.
            print_info: Stored on the instance as ``self.print_info``; not
                used by this class itself.
            eps: Added to the variance before the square root in the
                ``'gn'`` path so that constant feature columns do not cause
                a division by zero (NaN/inf outputs and gradients).
        """
        super().__init__()
        # No validation on purpose: every norm_type other than 'bn' / 'gn'
        # (e.g. 'ln' or None) leaves self.norm as None, i.e. identity.
        self.norm = None
        self.print_info = print_info
        self.eps = eps
        if norm_type == 'bn':
            self.norm = nn.BatchNorm1d(hidden_dim, track_running_stats=False)
        elif norm_type == 'gn':
            # self.norm holds the string tag; forward() dispatches on it.
            self.norm = norm_type
            self.weight = nn.Parameter(torch.ones(hidden_dim))
            self.bias = nn.Parameter(torch.zeros(hidden_dim))

            self.mean_scale = nn.Parameter(torch.ones(hidden_dim))

    def forward(self, tensor, print_=False):
        """Normalize ``tensor`` according to the configured mode.

        Args:
            tensor: Input tensor; statistics are taken over dim 0.
            print_: Unused; kept for interface compatibility.

        Returns:
            * identity mode: ``tensor`` itself (a bare tensor, not a tuple).
            * ``'bn'``: ``(normalized, bn.bias, bn.weight)``.
            * ``'gn'``: ``(result, offset, scaled_max)`` where
              ``offset = weight * mean * (1 - mean_scale) / std + bias`` is
              the part of ``result`` beyond ``weight * (x - mean) / std``, and
              ``scaled_max = weight * max(tensor, dim=0) / std``.
        """
        if self.norm is None:
            return tensor

        if not isinstance(self.norm, str):
            normalized = self.norm(tensor)
            return normalized, self.norm.bias, self.norm.weight

        # 'gn' path. Reductions already live on tensor's device, so no
        # explicit .to(device) is needed.
        mean = torch.mean(tensor, 0)
        centered = tensor - mean
        shifted = tensor - mean * self.mean_scale

        # eps inside the sqrt keeps both the value and the gradient finite
        # when a feature column has zero variance.
        variance = torch.sum(centered.pow(2), 0) / centered.size(dim=0)
        std = torch.sqrt(variance + self.eps)

        result = self.weight * shifted / std + self.bias

        max_elem, _ = torch.max(tensor, 0)
        offset = (self.weight * mean * (1 - self.mean_scale)) / std + self.bias
        scaled_max = (self.weight * max_elem) / std
        return result, offset, scaled_max
